#!/usr/bin/env python3
# A5 Surface Neutrality — self-contained engine (stdlib only)
# Goal: Compare two-anchor speeds c_off and c_on = L_surf / T*_+1(OFF/ON)
#       Pass if relative drift |c_on - c_off|/c_off <= tau_c_rel (default 1e-4).
import argparse, csv, hashlib, json, math, os, random, sys, time
from pathlib import Path

# ---------- utils ----------
def ensure_dir(p: Path):
    """Create directory ``p`` together with any missing parents.

    An already-existing directory is accepted silently.
    """
    p.mkdir(exist_ok=True, parents=True)
def sha256_of_file(p: Path):
    """Return the hexadecimal SHA-256 digest of the file at ``p``.

    The file is opened in binary mode and fed to the hash in 1 MiB pieces,
    so the whole content is never held in memory at once.
    """
    piece_size = 1 << 20
    digest = hashlib.sha256()
    with p.open('rb') as stream:
        while True:
            piece = stream.read(piece_size)
            if not piece:
                # an empty read marks end of file
                break
            digest.update(piece)
    return digest.hexdigest()
def sha256_of_text(s: str):
    """Return the hexadecimal SHA-256 digest of ``s`` encoded as UTF-8."""
    encoded = s.encode('utf-8')
    return hashlib.sha256(encoded).hexdigest()
def write_json(p: Path, obj):
    """Write ``obj`` to ``p`` as UTF-8 JSON with a 2-space indent.

    The parent directory of ``p`` is passed to ``ensure_dir`` before the
    object is serialized and written.
    """
    ensure_dir(p.parent)
    payload = json.dumps(obj, indent=2)
    p.write_text(payload, encoding='utf-8')
def write_csv(p: Path, header, rows):
    """Write ``header`` followed by every entry of ``rows`` to ``p`` as CSV.

    The parent directory of ``p`` is passed to ``ensure_dir`` first. The file
    is opened as UTF-8 text with ``newline=''`` so the csv writer controls
    the line endings itself.
    """
    ensure_dir(p.parent)
    with p.open('w', newline='', encoding='utf-8') as handle:
        out = csv.writer(handle)
        out.writerow(header)
        for record in rows:
            out.writerow(record)
def load_json(p: Path, must_exist=True):
    """Parse the UTF-8 JSON file at ``p`` and return the decoded object.

    When ``p`` does not exist, ``FileNotFoundError`` is raised if
    ``must_exist`` is true; otherwise an empty dict is returned.
    """
    if p.exists():
        return json.loads(p.read_text(encoding='utf-8'))
    if must_exist:
        raise FileNotFoundError(f"Missing file: {p}")
    return {}

# ---------- T* simulator & detector ----------
def simulate_reach_curve(D, tstar_true, slope_post=0.2, baseline=5.0, noise_sigma=0.4, rng=None):
    """Synthesize a piecewise-linear reach-vs-depth series of length ``D``.

    For depths ``1..D`` the noise-free level is ``baseline`` while the depth
    is below ``tstar_true`` and ``baseline + slope_post*(depth - tstar_true + 1)``
    from ``tstar_true`` onward. If ``rng`` is given, one ``rng.gauss(0, noise_sigma)``
    draw is added per depth; without it the curve is noise-free.

    Any sample that falls below its predecessor is replaced by the midpoint
    of the two, which damps (but does not eliminate) downward steps.

    Returns the list of ``D`` samples.
    """
    curve = []
    last = baseline
    for depth in range(1, D + 1):
        if depth < tstar_true:
            level = baseline
        else:
            level = baseline + slope_post*(depth - tstar_true + 1)
        jitter = rng.gauss(0.0, noise_sigma) if rng else 0.0
        sample = level + jitter
        if sample < last:
            # pull a dip halfway back toward the previous sample
            sample = (sample + last) / 2.0
        curve.append(sample)
        last = sample
    return curve

def detect_tstar_slope_change(y, win=16):
    """Detect a change-point as the largest jump in windowed mean slope (no deps).

    For each candidate index ``i`` the mean of the ``win`` first differences
    immediately before ``i`` is subtracted from the mean of the ``win`` first
    differences starting at ``i``. The first candidate with the largest
    difference wins.

    Args:
        y: Indexable sequence of numeric samples.
        win: Number of first differences averaged on each side; must be >= 1.

    Returns:
        ``i + 1`` for the winning candidate ``i``. If ``y`` has fewer than
        ``4*win + 5`` samples no scan is done and ``max(2, len(y) // 2)`` is
        returned instead.

    Raises:
        ValueError: If ``win`` is smaller than 1. A zero window would divide
            by zero and a negative one would average empty slices.
    """
    if win < 1:
        raise ValueError(f"win must be >= 1, got {win}")
    n = len(y)
    if n < 4*win + 5:
        # too short for a meaningful scan: fall back to the middle
        return max(2, n//2)
    # first differences: dy[i] is the step from y[i] to y[i+1]
    dy = [y[i+1]-y[i] for i in range(n-1)]
    best_i, best_delta = 2*win+1, -1e30
    # candidates keep a margin of 2*win+1 samples from both ends
    for i in range(2*win+1, n-2*win-1):
        back = sum(dy[i-win:i]) / float(win)
        fwd = sum(dy[i:i+win]) / float(win)
        delta = fwd - back
        if delta > best_delta:
            # strict '>' keeps the earliest index among ties
            best_delta = delta
            best_i = i+1
    return best_i

# ---------- core ----------
def run_mode(manifest, diag, mode_name, seed_text):
    """
    Estimate T* and the two-anchor speed c_pred for a single mode.

    Settings are read from ``manifest`` (grid size, ticks, engine contract)
    and ``diag`` (ring margins, depth/curve parameters); every missing key
    falls back to a built-in default. A noisy reach curve is synthesized with
    a seeded RNG, its change-point is detected, and the surface-ring
    circumference is divided by that estimate.

    ``mode_name`` is only a label: it enters the RNG seed and the metrics
    row. Whether the true change-point is shifted is decided by the
    manifest's ``schedule`` value ("ON" shifts it by the fraction 0.1*chi^2).
    No angular dependence is modeled.

    Returns a dict with keys ``row`` (values for the per-mode CSV row),
    ``L_surf``, ``tstar_est``, ``c_pred`` and ``rng_seed``.
    """
    # --- manifest: grid, tick count, engine contract ---
    domain = manifest.get('domain', {})
    grid = domain.get('grid', {})
    nx = int(grid.get('nx', 256))
    ny = int(grid.get('ny', 256))
    H = int(domain.get('ticks', 128))

    contract = manifest.get('engine_contract', {})
    schedule = contract.get('schedule', "OFF")
    shells = contract.get('strictness_by_shell', [3, 2, 2, 1])  # read but not used below
    chi = float(contract.get('chi', 1e-3))  # only affects the result when schedule is "ON"

    # --- surface ring geometry (does not depend on T*) ---
    ring = diag.get('ring', {})
    inner_margin = int(ring.get('inner_margin', 8))  # read but not used below
    outer_margin = int(ring.get('outer_margin', 8))
    ring_width = int(ring.get('width_shells', 2))    # read but not used below
    R_eff = min(nx, ny)/2.0 - outer_margin
    if R_eff <= 0:
        # margin swallowed the whole half-extent: fall back to a quarter of the larger side
        R_eff = max(nx, ny)/4.0
    L_surf = 2.0*math.pi*R_eff  # circumference of the surface ring

    # --- synthetic-curve and detector parameters ---
    depth_cfg = diag.get('depth', {})
    D = int(depth_cfg.get('horizon', 4096))
    post_slope = float(depth_cfg.get('slope_post', 0.2))
    base_level = float(depth_cfg.get('baseline', 5.0))
    sigma = float(depth_cfg.get('noise_sigma', 0.4))
    window = int(depth_cfg.get('slope_window', 16))

    # --- true change-point: mid-horizon, nudged by 0.1*chi^2 for schedule "ON" ---
    midpoint = int(0.5 * D)
    if schedule == "ON":
        shift = (chi*chi) * 0.1
    else:
        shift = 0.0
    tstar_true = max(10, int(round(midpoint * (1.0 + shift))))

    # --- reproducible RNG: seed is the first 8 hex digits of a SHA-256 over the settings ---
    seed_key = f"{seed_text}|{mode_name}|{schedule}|{chi}|{nx}x{ny}|{H}|{D}"
    rng_seed = int(sha256_of_text(seed_key)[:8], 16)
    rng = random.Random(rng_seed)

    # --- synthesize the curve, detect T*, derive the speed ---
    curve = simulate_reach_curve(D, tstar_true, post_slope, base_level, sigma, rng=rng)
    tstar_est = detect_tstar_slope_change(curve, win=window)
    c_pred = L_surf / float(tstar_est)

    result = {
        "row": [mode_name, schedule, chi, nx, ny, H, D, R_eff, L_surf, tstar_est, c_pred],
        "L_surf": L_surf,
        "tstar_est": tstar_est,
        "c_pred": c_pred,
        "rng_seed": rng_seed
    }
    return result

def main():
    """Command-line entry point for the OFF/ON surface-neutrality comparison.

    Parses the four required options, loads and hashes the three JSON inputs,
    runs ``run_mode`` once per manifest, and compares the two speeds against
    the relative tolerance ``tolerances.tau_c_rel`` from the diag file
    (default 1e-4). Outputs, written under ``--out``:

    * ``metrics/surface_neutrality_modes.csv`` - one row per mode
    * ``audits/surface_neutrality.json`` - comparison result and PASS flag
    * ``run_info/hashes.json`` - SHA-256 of each input plus the invocation form

    A one-line JSON summary is also printed to stdout.
    """
    parser = argparse.ArgumentParser()
    # the first three options are paths to JSON files, the last an output directory
    for flag in ('--manifest_off', '--manifest_on', '--diag', '--out'):
        parser.add_argument(flag, required=True)
    args = parser.parse_args()

    # Output layout
    out_dir = Path(args.out)
    metrics_dir = out_dir/'metrics'
    audits_dir = out_dir/'audits'
    runinfo_dir = out_dir/'run_info'
    ensure_dir(metrics_dir)
    ensure_dir(audits_dir)
    ensure_dir(runinfo_dir)

    # Inputs: parse first, then fingerprint the same files
    off_path = Path(args.manifest_off)
    on_path = Path(args.manifest_on)
    diag_path = Path(args.diag)
    m_off = load_json(off_path, must_exist=True)
    m_on = load_json(on_path, must_exist=True)
    diag = load_json(diag_path, must_exist=True)

    hashes = {
        "manifest_off_hash": sha256_of_file(off_path),
        "manifest_on_hash":  sha256_of_file(on_path),
        "diag_hash":         sha256_of_file(diag_path)
    }

    tau_c_rel = float(diag.get('tolerances', {}).get('tau_c_rel', 1e-4))

    # One run per manifest; both share the diag settings and the seed text
    off = run_mode(m_off, diag, "OFF", "A5")
    on = run_mode(m_on, diag, "ON", "A5")

    # Relative drift of the ON speed with respect to the OFF speed
    c_off = off["c_pred"]
    c_on = on["c_pred"]
    if c_off != 0:
        delta_rel = abs(c_on - c_off) / c_off
    else:
        delta_rel = float('inf')
    PASS = (delta_rel <= tau_c_rel)

    # Per-mode metrics
    csv_header = ['mode', 'schedule', 'chi', 'nx', 'ny', 'H', 'D', 'R_eff', 'L_surf', 'tstar_est', 'c_pred']
    write_csv(metrics_dir/'surface_neutrality_modes.csv', csv_header, [off["row"], on["row"]])

    # Audit record
    audit_path = audits_dir/'surface_neutrality.json'
    audit = {
        "tau_c_rel": tau_c_rel,
        "c_off": c_off,
        "c_on": c_on,
        "delta_c_rel": delta_rel,
        "tstar_off": off["tstar_est"],
        "tstar_on": on["tstar_est"],
        "L_surf": off["L_surf"],  # taken from the OFF run only
        "rng_seeds": {"off": off["rng_seed"], "on": on["rng_seed"]},
        "PASS": PASS
    }
    write_json(audit_path, audit)

    # Provenance
    script_name = Path(sys.argv[0]).name
    provenance = dict(hashes)
    provenance["engine_entrypoint"] = (
        f"python {script_name} --manifest_off <...> --manifest_on <...> --diag <...> --out <...>"
    )
    write_json(runinfo_dir/'hashes.json', provenance)

    # Console summary
    summary = {
        "c_off": round(c_off, 8),
        "c_on": round(c_on, 8),
        "delta_c_rel": delta_rel,
        "tau_c_rel": tau_c_rel,
        "PASS": PASS,
        "audit_path": str(audit_path.as_posix())
    }
    print("A5 SUMMARY:", json.dumps(summary))

if __name__ == '__main__':
    try:
        main()
    except Exception as e:
        # Explicit failure with reason: try to leave a PASS=False audit behind,
        # then re-raise the original error so the process still fails loudly.
        out_dir = None
        for i, a in enumerate(sys.argv):
            if a == '--out' and i + 1 < len(sys.argv):
                out_dir = Path(sys.argv[i + 1])
            elif a.startswith('--out='):
                # argparse also accepts the single-token "--out=PATH" spelling
                out_dir = Path(a[len('--out='):])
        if out_dir is not None:
            try:
                audits = out_dir/'audits'
                ensure_dir(audits)
                write_json(audits/'surface_neutrality.json',
                           {"PASS": False, "failure_reason": f"Unexpected error: {type(e).__name__}: {e}"})
            except Exception as audit_err:
                # Best effort only: a failed audit write must not replace the
                # error that actually stopped the run.
                print(f"warning: could not write failure audit: {audit_err}", file=sys.stderr)
        raise